/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/* !
 * \file softmax_flashv2_impl.h
 * \brief
 */
#ifndef IMPL_ACTIVATION_SOFTMAX_C310_SOFTMAX_FLASHV2_IMPL_H
#define IMPL_ACTIVATION_SOFTMAX_C310_SOFTMAX_FLASHV2_IMPL_H

#include "softmax_common_impl.h"

namespace AscendC {

template <typename T1, typename T2, bool isOutputReduceMax = false>
__simd_vf__ inline void SoftmaxFlashV2M1NDUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* reduceMaxUb,
    __local_mem__ T2* expSumUb, __local_mem__ T2* inExpSumUb, __local_mem__ T2* maxUb, __local_mem__ T2* inMaxUb,
    __local_mem__ T1* srcUb, __local_mem__ T1* expMaxUb, const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = originalSrcShape.m;
    uint16_t srcK = tiling.srcK;
    uint16_t originK = originalSrcShape.k;
    uint16_t repeatTimes = CeilDivision(originK, FLOAT_REPEAT_SIZE);
    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOnePt = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL1>();
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> expMaxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> dstVreg;

    for (uint16_t i = 0; i < srcM; ++i) {
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < repeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);
        }
        MicroAPI::ReduceMax(maxVreg, maxVreg, pregFull);
        if constexpr (isOutputReduceMax) {
            StoreIfNeedCastM1<T2>(reduceMaxUb + i, maxVreg, pregOnePt);
        }
        LoadIfNeedCastM1<T2>(tmpVreg, inMaxUb + i, pregOnePt);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregOnePt);
        StoreIfNeedCastM1<T2>(maxUb + i, maxVreg, pregOnePt);

        MicroAPI::FusedExpSub(expMaxVreg, tmpVreg, maxVreg, pregOnePt);
        StoreIfNeedCastM1<T1>(expMaxUb + i, expMaxVreg, pregOnePt);

        Duplicate(sumVreg, 0);
        Duplicate(maxVreg, maxVreg, pregFull);
        for (uint16_t j = 0; j < repeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregFull);
            StoreIfNeedCast<T1>(dstUb + i * srcK + j * FLOAT_REPEAT_SIZE, tmpVreg, pregFull);
            MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        }
        MicroAPI::ReduceSum(sumVreg, sumVreg, pregFull);
        LoadIfNeedCastM1<T2>(tmpVreg, inExpSumUb + i, pregOnePt);
        MicroAPI::Mul(tmpVreg, expMaxVreg, tmpVreg, pregOnePt);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregOnePt);
        StoreIfNeedCastM1<T2>(expSumUb + i, sumVreg, pregOnePt);
    }
}

template <typename T1, typename T2, bool isOutputReduceMax = false>
__aicore__ inline void SoftmaxFlashV2M1NDUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& outReduceMax,
    const LocalTensor<T2>& expSumTensor, const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor,
    const LocalTensor<T1>& expMaxTensor, const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor,
    const LocalTensor<float>& workLocal, const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* reduceMaxUb = (__local_mem__ T2*)outReduceMax.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* inExpSumUb = (__local_mem__ T2*)inExpSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T2* inMaxUb = (__local_mem__ T2*)inMaxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T1* expMaxUb = (__local_mem__ T1*)expMaxTensor.GetPhyAddr();

    SoftmaxFlashV2M1NDUpdateVFImpl<T1, T2, isOutputReduceMax>(dstUb, reduceMaxUb, expSumUb, inExpSumUb, maxUb, inMaxUb, srcUb,
        expMaxUb, originalSrcShape, tiling);
}

template <typename T1, typename T2, bool isOutputReduceMax = false>
__simd_vf__ inline void SoftmaxFlashV2M1NDWithTailUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* reduceMaxUb,
    __local_mem__ T2* expSumUb, __local_mem__ T2* inExpSumUb, __local_mem__ T2* maxUb, __local_mem__ T2* inMaxUb,
    __local_mem__ T1* srcUb, __local_mem__ T1* expMaxUb, const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = originalSrcShape.m;
    uint16_t srcK = tiling.srcK;
    uint16_t originK = originalSrcShape.k;
    uint16_t repeatTimes = CeilDivision(originK, FLOAT_REPEAT_SIZE);
    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregCnt;
    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOnePt = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL1>();
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> expMaxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> minVreg;
    MicroAPI::RegTensor<float> dstVreg;

    Duplicate(minVreg, notNum.f);
    for (uint16_t i = 0; i < srcM; ++i) {
        uint32_t sreg = originK;
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < static_cast<uint16_t>(repeatTimes - 1); ++j) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregCnt);
        }
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + (repeatTimes - 1) * FLOAT_REPEAT_SIZE, pregFull);
        MicroAPI::Select(srcVreg, srcVreg, minVreg, pregCnt);
        MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);

        MicroAPI::ReduceMax(maxVreg, maxVreg, pregFull);
        if constexpr (isOutputReduceMax) {
            StoreIfNeedCastM1<T2>(reduceMaxUb + i, maxVreg, pregOnePt);
        }
        LoadIfNeedCastM1<T2>(tmpVreg, inMaxUb + i, pregOnePt);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregOnePt);
        StoreIfNeedCastM1<T2>(maxUb + i, maxVreg, pregOnePt);

        MicroAPI::FusedExpSub(expMaxVreg, tmpVreg, maxVreg, pregOnePt);
        StoreIfNeedCastM1<T1>(expMaxUb + i, expMaxVreg, pregOnePt);

        Duplicate(sumVreg, 0);
        Duplicate(maxVreg, maxVreg, pregFull);
        sreg = originK;
        for (uint16_t j = 0; j < repeatTimes; ++j) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregCnt);
            StoreIfNeedCast<T1>(dstUb + i * srcK + j * FLOAT_REPEAT_SIZE, tmpVreg, pregCnt);
            MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        }
        MicroAPI::ReduceSum(sumVreg, sumVreg, pregFull);
        LoadIfNeedCastM1<T2>(tmpVreg, inExpSumUb + i, pregOnePt);
        MicroAPI::Mul(tmpVreg, expMaxVreg, tmpVreg, pregOnePt);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregOnePt);
        StoreIfNeedCastM1<T2>(expSumUb + i, sumVreg, pregOnePt);
    }
}

template <typename T1, typename T2, bool isOutputReduceMax = false>
__aicore__ inline void SoftmaxFlashV2M1NDWithTailUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& outReduceMax,
    const LocalTensor<T2>& expSumTensor, const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor,
    const LocalTensor<T1>& expMaxTensor, const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor,
    const LocalTensor<float>& workLocal, const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* reduceMaxUb = (__local_mem__ T2*)outReduceMax.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* inExpSumUb = (__local_mem__ T2*)inExpSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T2* inMaxUb = (__local_mem__ T2*)inMaxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T1* expMaxUb = (__local_mem__ T1*)expMaxTensor.GetPhyAddr();

    SoftmaxFlashV2M1NDWithTailUpdateVFImpl<T1, T2, isOutputReduceMax>(dstUb, reduceMaxUb, expSumUb, inExpSumUb, maxUb, inMaxUb, srcUb,
        expMaxUb, originalSrcShape, tiling);
}

template <typename T1, typename T2>
__simd_vf__ inline void SoftmaxFlashV2NZNoUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* expSumUb,
    __local_mem__ T2* maxUb, __local_mem__ T1* srcUb, __local_mem__ T2* tmpUb, __local_mem__ float* workUb,
    __local_mem__ float* expUb, const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t originK = originalSrcShape.k;
    uint16_t originM = originalSrcShape.m;
    uint16_t dataBlock = srcM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t mRepeatTimes = CeilDivision(dataBlock, FLOAT_REPEAT_SIZE);
    uint16_t kRepeatTimes = originK / SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t VcgFoldRepeat = CeilDivision(srcM, FLOAT_REPEAT_SIZE);
    uint16_t e2bRep = srcM / DEFAULT_BLK_NUM;
    uint16_t dtypeRepStride = IsSameType<T2, half>::value ? HALF_REPEAT_SIZE : FLOAT_REPEAT_SIZE;
    uint16_t dtypeBlkStride = dtypeRepStride / DEFAULT_BLK_NUM;

    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregCnt;
    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL8>();
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> dstVreg;
    MicroAPI::RegTensor<T2> castReg;

    // reducemax
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);
        }
        MicroAPI::ReduceMaxWithDataBlock(maxVreg, maxVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, maxVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            maxVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregFull);
        StoreIfNeedCast<T2>(tmpUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
            workUb + i * HALF_REPEAT_SIZE, maxVreg, maxVreg, pregFull);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    uint32_t sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<T2>(sreg);
        LoadE2B<T2>(castReg, tmpUb + i * DEFAULT_BLK_NUM);
        MicroAPI::DataCopy(maxUb + i * dtypeRepStride, castReg, pregCnt);
    }

    for (uint16_t j = 0; j < kRepeatTimes; ++j) {
        sreg = originM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
        for (uint16_t i = 0; i < mRepeatTimes; ++i) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            LoadE2B<float>(maxVreg, workUb + i * DEFAULT_BLK_NUM);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregFull);
            StoreIfNeedCast<T1>(dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            }
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    // reducesum
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(sumVreg, 0);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(srcVreg, expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            } else {
                MicroAPI::DataCopy(srcVreg, dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            }
            MicroAPI::Add(sumVreg, sumVreg, srcVreg, pregFull);
        }
        MicroAPI::ReduceSumWithDataBlock(sumVreg, sumVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, sumVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            sumVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        StoreIfNeedCast<T2>(tmpUb + i * FLOAT_REPEAT_SIZE, sumVreg, pregFull);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<T2>(sreg);
        LoadE2B<T2>(castReg, tmpUb + i * DEFAULT_BLK_NUM);
        MicroAPI::DataCopy(expSumUb + i * dtypeRepStride, castReg, pregCnt);
    }
}

template <typename T1, typename T2>
__aicore__ inline void SoftmaxFlashV2NZNoUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t VcgFoldRepeat = CeilDivision(srcM, FLOAT_REPEAT_SIZE);
    uint16_t offset = CeilDivision(srcM * SOFTMAX_COMPUTE_DIM, HALF_REPEAT_SIZE) * HALF_REPEAT_SIZE;

    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T2* tmpUb = (__local_mem__ T2*)workLocal.GetPhyAddr();
    __local_mem__ float* workUb = (__local_mem__ float*)workLocal.GetPhyAddr(VcgFoldRepeat * FLOAT_REPEAT_SIZE);
    __local_mem__ float* expUb = (__local_mem__ float*)workLocal.GetPhyAddr(VcgFoldRepeat * FLOAT_REPEAT_SIZE + offset);

    SoftmaxFlashV2NZNoUpdateVFImpl<T1, T2>(dstUb, expSumUb, maxUb, srcUb, tmpUb, workUb, expUb, originalSrcShape, tiling);
}

template <typename T1, typename T2>
__simd_vf__ inline void SoftmaxFlashV2NZWithTailNoUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* expSumUb,
    __local_mem__ T2* maxUb, __local_mem__ T1* srcUb, __local_mem__ T2* tmpUb, __local_mem__ float* workUb,
    __local_mem__ float* expUb, const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t originK = originalSrcShape.k;
    uint16_t originM = originalSrcShape.m;

    uint16_t dataBlock = srcM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t mRepeatTimes = CeilDivision(dataBlock, FLOAT_REPEAT_SIZE);
    uint16_t kRepeatTimes = originK / SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t VcgFoldRepeat = CeilDivision(srcM, FLOAT_REPEAT_SIZE);
    uint16_t e2bRep = srcM / DEFAULT_BLK_NUM;
    uint16_t dtypeRepStride = IsSameType<T2, half>::value ? HALF_REPEAT_SIZE : FLOAT_REPEAT_SIZE;
    uint16_t dtypeBlkStride = dtypeRepStride / DEFAULT_BLK_NUM;
    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregDst;
    MicroAPI::MaskReg pregkTail = MicroAPI::MoveMask<uint32_t>();
    MicroAPI::MaskReg pregCnt;
    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL8>();
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> minVreg;
    MicroAPI::RegTensor<float> dstVreg;
    MicroAPI::RegTensor<T2> castReg;

    // reducemax
    Duplicate(minVreg, notNum.f);
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);
        }
        LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, pregFull);
        MicroAPI::Select(srcVreg, srcVreg, minVreg, pregkTail);
        MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);

        MicroAPI::ReduceMaxWithDataBlock(maxVreg, maxVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, maxVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            maxVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregFull);
        StoreIfNeedCast<T2>(tmpUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
            workUb + i * HALF_REPEAT_SIZE, maxVreg, maxVreg, pregFull);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    uint32_t sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<T2>(sreg);
        LoadE2B<T2>(castReg, tmpUb + i * DEFAULT_BLK_NUM);
        MicroAPI::DataCopy(maxUb + i * dtypeRepStride, castReg, pregCnt);
    }

    for (uint16_t j = 0; j < kRepeatTimes; ++j) {
        sreg = originM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
        for (uint16_t i = 0; i < mRepeatTimes; ++i) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            LoadE2B<float>(maxVreg, workUb + i * DEFAULT_BLK_NUM);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregFull);
            StoreIfNeedCast<T1>(dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            }
        }
    }
    sreg = originM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        LoadE2B<float>(maxVreg, workUb + i * DEFAULT_BLK_NUM);
        LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, pregFull);
        MicroAPI::MaskAnd(pregDst, pregkTail, pregCnt, pregFull);
        MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregDst);
        StoreIfNeedCast<T1>(dstUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, tmpVreg, pregDst);
        if constexpr (sizeof(T1) == 2) {
            MicroAPI::DataCopy(expUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, tmpVreg, pregDst);
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    // reducesum
    Duplicate(minVreg, 0);
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(sumVreg, 0);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(srcVreg, expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            } else {
                MicroAPI::DataCopy(srcVreg, dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            }
            MicroAPI::Add(sumVreg, sumVreg, srcVreg, pregFull);
        }
        if constexpr (sizeof(T1) == 2) {
            MicroAPI::DataCopy(srcVreg, expUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock);
        } else {
            MicroAPI::DataCopy(srcVreg, dstUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock);
        }
        MicroAPI::Select(srcVreg, srcVreg, minVreg, pregkTail);
        MicroAPI::Add(sumVreg, sumVreg, srcVreg, pregFull);

        MicroAPI::ReduceSumWithDataBlock(sumVreg, sumVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, sumVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            sumVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        StoreIfNeedCast<T2>(tmpUb + i * FLOAT_REPEAT_SIZE, sumVreg, pregFull);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<T2>(sreg);
        LoadE2B<T2>(castReg, tmpUb + i * DEFAULT_BLK_NUM);
        MicroAPI::DataCopy(expSumUb + i * dtypeRepStride, castReg, pregCnt);
    }
}

template <typename T1, typename T2>
__aicore__ inline void SoftmaxFlashV2NZWithTailNoUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t originK = originalSrcShape.k;
    uint16_t VcgFoldRepeat = CeilDivision(srcM, FLOAT_REPEAT_SIZE);
    uint16_t offset = CeilDivision(srcM * SOFTMAX_COMPUTE_DIM, HALF_REPEAT_SIZE) * HALF_REPEAT_SIZE;
    uint64_t mask[2] = {0, 0};
    CreateSpecialFormatMask(mask[0], originK % SOFTMAX_SHAPE_NZ_BASIC_COUNT, FLOAT_REPEAT_SIZE / SOFTMAX_SHAPE_NZ_BASIC_COUNT);
    SetVectorMask<uint32_t>(mask[1], mask[0]);

    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T2* tmpUb = (__local_mem__ T2*)workLocal.GetPhyAddr();
    __local_mem__ float* workUb = (__local_mem__ float*)workLocal.GetPhyAddr(VcgFoldRepeat * FLOAT_REPEAT_SIZE);
    __local_mem__ float* expUb = (__local_mem__ float*)workLocal.GetPhyAddr(VcgFoldRepeat * FLOAT_REPEAT_SIZE + offset);

    SoftmaxFlashV2NZWithTailNoUpdateVFImpl<T1, T2>(dstUb, expSumUb, maxUb, srcUb, tmpUb, workUb,
        expUb, originalSrcShape, tiling);
}

template <typename T1, typename T2>
__simd_vf__ inline void SoftmaxFlashV2NZUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* expSumUb, __local_mem__ T2* inExpSumUb,
    __local_mem__ T2* maxUb, __local_mem__ T2* inMaxUb, __local_mem__ T1* srcUb, __local_mem__ T1* expMaxUb,
    __local_mem__ float* expMaxF32Ub, __local_mem__ float* workUb, __local_mem__ float* tmpUb, __local_mem__ float* expUb,
    const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t originK = originalSrcShape.k;
    uint16_t originM = originalSrcShape.m;

    uint16_t dataBlock = srcM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t mRepeatTimes = CeilDivision(dataBlock, FLOAT_REPEAT_SIZE);
    uint16_t kRepeatTimes = originK / SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t VcgFoldRepeat = CeilDivision(srcM, FLOAT_REPEAT_SIZE);
    uint16_t e2bRep = IsSameType<T2, float>::value ? srcM / DEFAULT_BLK_NUM : mRepeatTimes;
    uint16_t dtypeRepStride = IsSameType<T2, half>::value ? HALF_REPEAT_SIZE : FLOAT_REPEAT_SIZE;
    uint16_t dtypeBlkStride = dtypeRepStride / DEFAULT_BLK_NUM;
    
    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregCnt;
    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL8>();
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> inMaxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> dstVreg;
    MicroAPI::RegTensor<T1> t1Reg;

    // reducemax
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);
        }
        MicroAPI::ReduceMaxWithDataBlock(maxVreg, maxVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, maxVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            maxVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregFull);
        if constexpr (sizeof(T2) == 4) {
            MicroAPI::DataCopy(workUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregFull);
        } else {
            MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
                workUb + i * HALF_REPEAT_SIZE, maxVreg, maxVreg, pregFull);
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    uint32_t sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        LoadE2B<float>(maxVreg, workUb + i * DEFAULT_BLK_NUM);
        LoadIfNeedCast<T2>(inMaxVreg, inMaxUb + i * FLOAT_REPEAT_SIZE, pregCnt);
        MicroAPI::Max(maxVreg, inMaxVreg, maxVreg, pregCnt);
        StoreIfNeedCast<T2>(maxUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregCnt);
        if constexpr (sizeof(T2) == 4) {
            MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
                tmpUb + i * HALF_REPEAT_SIZE, maxVreg, maxVreg, pregFull);
        } else {
            MicroAPI::DataCopy(tmpUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregFull);
        }
        MicroAPI::FusedExpSub(dstVreg, inMaxVreg, maxVreg, pregCnt);
        MicroAPI::DataCopy(expMaxF32Ub + i * FLOAT_REPEAT_SIZE, dstVreg, pregCnt);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t j = 0; j < kRepeatTimes; ++j) {
        sreg = originM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
        for (uint16_t i = 0; i < mRepeatTimes; ++i) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            MicroAPI::DataCopy(maxVreg, tmpUb + i * FLOAT_REPEAT_SIZE);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregCnt);
            StoreIfNeedCast<T1>(dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            }
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    // reducesum
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(sumVreg, 0);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(tmpVreg, expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            } else {
                MicroAPI::DataCopy(tmpVreg, dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            }
            MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        }
        MicroAPI::ReduceSumWithDataBlock(sumVreg, sumVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, sumVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            sumVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        if constexpr (sizeof(T2) == 4) {
            MicroAPI::DataCopy(workUb + i * FLOAT_REPEAT_SIZE, sumVreg, pregFull);
        } else {
            MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
                workUb + i * HALF_REPEAT_SIZE, sumVreg, sumVreg, pregFull);
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        LoadE2B<float>(sumVreg, workUb + i * DEFAULT_BLK_NUM);
        LoadIfNeedCast<T2>(tmpVreg, inExpSumUb + i * FLOAT_REPEAT_SIZE, pregCnt);
        MicroAPI::DataCopy(maxVreg, expMaxF32Ub + i * FLOAT_REPEAT_SIZE);
        MicroAPI::Mul(dstVreg, maxVreg, tmpVreg, pregCnt);
        MicroAPI::Add(sumVreg, sumVreg, dstVreg, pregCnt);
        StoreIfNeedCast<T2>(expSumUb + i * FLOAT_REPEAT_SIZE, sumVreg, pregCnt);
        if constexpr (sizeof(T1) == 2 && sizeof(T2) == 4) {
            MicroAPI::Cast<T1, float, Internal::castTraitB32ToB16>(t1Reg, maxVreg, pregCnt);
            MicroAPI::Pack<uint16_t, uint32_t>((MicroAPI::RegTensor<uint16_t> &)t1Reg,
                (MicroAPI::RegTensor<uint32_t> &)t1Reg);
            MicroAPI::DataCopy<T1, MicroAPI::StoreDist::DIST_INTLV_B16>(
                expMaxUb + i * HALF_REPEAT_SIZE, t1Reg, t1Reg, pregCnt);
        } else if constexpr (sizeof(T1) == 2 && sizeof(T2) == 2) {
            StoreIfNeedCast<T1>(expMaxUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregCnt);
        }
    }
}

template <typename T1, typename T2>
__aicore__ inline void SoftmaxFlashV2NZUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<T1>& expMaxTensor,
    const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t originK = originalSrcShape.k;
    uint16_t offset;
    if constexpr (sizeof(T2) == 4) {
        offset = CeilDivision(srcM, FLOAT_REPEAT_SIZE) * FLOAT_REPEAT_SIZE;
    } else {
        offset = CeilDivision(srcM * SOFTMAX_COMPUTE_DIM, HALF_REPEAT_SIZE) * HALF_REPEAT_SIZE;
    }
    uint16_t offset1 = srcM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;

    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* inExpSumUb = (__local_mem__ T2*)inExpSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T2* inMaxUb = (__local_mem__ T2*)inMaxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T1* expMaxUb = (__local_mem__ T1*)expMaxTensor.GetPhyAddr();
    __local_mem__ float* expMaxF32Ub = (__local_mem__ float*)expMaxTensor.GetPhyAddr();
    __local_mem__ float* workUb = (__local_mem__ float*)workLocal.GetPhyAddr();
    __local_mem__ float* tmpUb = (__local_mem__ float*)workLocal.GetPhyAddr(offset);
    __local_mem__ float* expUb = (__local_mem__ float*)workLocal.GetPhyAddr(offset + offset1);
    if constexpr (sizeof(T1) == 2) {
        expMaxF32Ub = (__local_mem__ float*)workLocal.GetPhyAddr(offset + offset1 + srcM * originK);
    }

    SoftmaxFlashV2NZUpdateVFImpl<T1, T2>(dstUb, expSumUb, inExpSumUb, maxUb, inMaxUb, srcUb, expMaxUb,
        expMaxF32Ub, workUb, tmpUb, expUb, originalSrcShape, tiling);
}

template <typename T1, typename T2>
__simd_vf__ inline void SoftmaxFlashV2NZWithTailUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* expSumUb, __local_mem__ T2* inExpSumUb,
    __local_mem__ T2* maxUb, __local_mem__ T2* inMaxUb, __local_mem__ T1* srcUb, __local_mem__ T1* expMaxUb,
    __local_mem__ float* expMaxF32Ub, __local_mem__ float* workUb, __local_mem__ float* tmpUb, __local_mem__ float* expUb,
    const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t originK = originalSrcShape.k;
    uint16_t originM = originalSrcShape.m;
    uint16_t dataBlock = srcM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t mRepeatTimes = CeilDivision(dataBlock, FLOAT_REPEAT_SIZE);
    uint16_t kRepeatTimes = originK / SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint16_t VcgFoldRepeat = CeilDivision(srcM, FLOAT_REPEAT_SIZE);
    uint16_t e2bRep = IsSameType<T2, float>::value ? srcM / DEFAULT_BLK_NUM : mRepeatTimes;
    uint16_t dtypeRepStride = IsSameType<T2, half>::value ? HALF_REPEAT_SIZE : FLOAT_REPEAT_SIZE;
    uint16_t dtypeBlkStride = dtypeRepStride / DEFAULT_BLK_NUM;
    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregDst;
    MicroAPI::MaskReg pregkTail = MicroAPI::MoveMask<uint32_t>();
    MicroAPI::MaskReg pregCnt;
    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL8>();
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> inMaxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> minVreg;
    MicroAPI::RegTensor<float> dstVreg;
    MicroAPI::RegTensor<T1> t1Reg;

    // reducemax
    Duplicate(minVreg, notNum.f);
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);
        }
        LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, pregFull);
        MicroAPI::Select(srcVreg, srcVreg, minVreg, pregkTail);
        MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);

        MicroAPI::ReduceMaxWithDataBlock(maxVreg, maxVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, maxVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            maxVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregFull);
        if constexpr (sizeof(T2) == 4) {
            MicroAPI::DataCopy(workUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregFull);
        } else {
            MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
                workUb + i * HALF_REPEAT_SIZE, maxVreg, maxVreg, pregFull);
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    uint32_t sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        LoadE2B<float>(maxVreg, workUb + i * DEFAULT_BLK_NUM);
        LoadIfNeedCast<T2>(inMaxVreg, inMaxUb + i * FLOAT_REPEAT_SIZE, pregCnt);
        MicroAPI::Max(maxVreg, inMaxVreg, maxVreg, pregCnt);
        StoreIfNeedCast<T2>(maxUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregCnt);
        if constexpr (sizeof(T2) == 4) {
            MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
                tmpUb + i * HALF_REPEAT_SIZE, maxVreg, maxVreg, pregFull);
        } else {
            MicroAPI::DataCopy(tmpUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregFull);
        }
        MicroAPI::FusedExpSub(dstVreg, inMaxVreg, maxVreg, pregCnt);
        MicroAPI::DataCopy(expMaxF32Ub + i * FLOAT_REPEAT_SIZE, dstVreg, pregCnt);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t j = 0; j < kRepeatTimes; ++j) {
        sreg = originM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
        for (uint16_t i = 0; i < mRepeatTimes; ++i) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            MicroAPI::DataCopy(maxVreg, tmpUb + i * FLOAT_REPEAT_SIZE);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregCnt);
            StoreIfNeedCast<T1>(dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock, tmpVreg, pregCnt);
            }
        }
    }
    sreg = originM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        MicroAPI::DataCopy(maxVreg, tmpUb + i * FLOAT_REPEAT_SIZE);
        LoadIfNeedCast<T1>(srcVreg, srcUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, pregFull);
        MicroAPI::MaskAnd(pregDst, pregkTail, pregCnt, pregFull);
        MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregDst);
        StoreIfNeedCast<T1>(dstUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, tmpVreg, pregDst);
        if constexpr (sizeof(T1) == 2) {
            MicroAPI::DataCopy(expUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock, tmpVreg, pregDst);
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    // reducesum
    Duplicate(minVreg, 0);
    for (uint16_t i = 0; i < mRepeatTimes; ++i) {
        Duplicate(sumVreg, 0);
        for (uint16_t j = 0; j < kRepeatTimes; ++j) {
            if constexpr (sizeof(T1) == 2) {
                MicroAPI::DataCopy(tmpVreg, expUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            } else {
                MicroAPI::DataCopy(tmpVreg, dstUb + i * FLOAT_REPEAT_SIZE + j * dataBlock);
            }
            MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        }
        if constexpr (sizeof(T1) == 2) {
            MicroAPI::DataCopy(tmpVreg, expUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock);
        } else {
            MicroAPI::DataCopy(tmpVreg, dstUb + i * FLOAT_REPEAT_SIZE + kRepeatTimes * dataBlock);
        }
        MicroAPI::Select(tmpVreg, tmpVreg, minVreg, pregkTail);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);

        MicroAPI::ReduceSumWithDataBlock(sumVreg, sumVreg, pregFull);
        MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_NORM>(workUb + i * DEFAULT_BLK_NUM, sumVreg, pregOneBlk);
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    for (uint16_t i = 0; i < VcgFoldRepeat; ++i) {
        MicroAPI::DataCopy<float, MicroAPI::LoadDist::DIST_DINTLV_B32>(
            sumVreg, tmpVreg, workUb + i * HALF_REPEAT_SIZE);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        if constexpr (sizeof(T2) == 4) {
            MicroAPI::DataCopy(workUb + i * FLOAT_REPEAT_SIZE, sumVreg, pregFull);
        } else {
            MicroAPI::DataCopy<float, MicroAPI::StoreDist::DIST_INTLV_B32>(
                workUb + i * HALF_REPEAT_SIZE, sumVreg, sumVreg, pregFull);
        }
    }

    MicroAPI::LocalMemBar<MicroAPI::MemType::VEC_STORE, MicroAPI::MemType::VEC_LOAD>();

    sreg = originM * dtypeBlkStride;
    for (uint16_t i = 0; i < e2bRep; ++i) {
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        LoadE2B<float>(sumVreg, workUb + i * DEFAULT_BLK_NUM);
        LoadIfNeedCast<T2>(tmpVreg, inExpSumUb + i * FLOAT_REPEAT_SIZE, pregCnt);
        MicroAPI::DataCopy(maxVreg, expMaxF32Ub + i * FLOAT_REPEAT_SIZE);
        MicroAPI::Mul(dstVreg, maxVreg, tmpVreg, pregCnt);
        MicroAPI::Add(sumVreg, sumVreg, dstVreg, pregCnt);
        StoreIfNeedCast<T2>(expSumUb + i * FLOAT_REPEAT_SIZE, sumVreg, pregCnt);
        if constexpr (sizeof(T1) == 2 && sizeof(T2) == 4) {
            MicroAPI::Cast<T1, float, Internal::castTraitB32ToB16>(t1Reg, maxVreg, pregCnt);
            MicroAPI::Pack<uint16_t, uint32_t>((MicroAPI::RegTensor<uint16_t> &)t1Reg,
                (MicroAPI::RegTensor<uint32_t> &)t1Reg);
            MicroAPI::DataCopy<T1, MicroAPI::StoreDist::DIST_INTLV_B16>(
                expMaxUb + i * HALF_REPEAT_SIZE, t1Reg, t1Reg, pregCnt);
        } else if constexpr (sizeof(T1) == 2 && sizeof(T2) == 2) {
            StoreIfNeedCast<T1>(expMaxUb + i * FLOAT_REPEAT_SIZE, maxVreg, pregCnt);
        }
    }
}

template <typename T1, typename T2>
__aicore__ inline void SoftmaxFlashV2NZWithTailUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<T1>& expMaxTensor,
    const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    uint16_t srcM = tiling.srcM;
    uint16_t srcK = tiling.srcK;
    uint16_t originK = originalSrcShape.k;

    uint16_t offset;
    if constexpr (sizeof(T2) == 4) {
        offset = CeilDivision(srcM, FLOAT_REPEAT_SIZE) * FLOAT_REPEAT_SIZE;
    } else {
        offset = CeilDivision(srcM * SOFTMAX_COMPUTE_DIM, HALF_REPEAT_SIZE) * HALF_REPEAT_SIZE;
    }
    uint16_t offset1 = srcM * SOFTMAX_SHAPE_NZ_BASIC_COUNT;
    uint64_t mask[2] = {0, 0};
    CreateSpecialFormatMask(mask[0], originK % SOFTMAX_SHAPE_NZ_BASIC_COUNT, FLOAT_REPEAT_SIZE / SOFTMAX_SHAPE_NZ_BASIC_COUNT);
    SetVectorMask<uint32_t>(mask[1], mask[0]);

    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* inExpSumUb = (__local_mem__ T2*)inExpSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T2* inMaxUb = (__local_mem__ T2*)inMaxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T1* expMaxUb = (__local_mem__ T1*)expMaxTensor.GetPhyAddr();
    __local_mem__ float* expMaxF32Ub = (__local_mem__ float*)expMaxTensor.GetPhyAddr();
    __local_mem__ float* workUb = (__local_mem__ float*)workLocal.GetPhyAddr();
    __local_mem__ float* tmpUb = (__local_mem__ float*)workLocal.GetPhyAddr(offset);
    __local_mem__ float* expUb = (__local_mem__ float*)workLocal.GetPhyAddr(offset + offset1);
    if constexpr (sizeof(T1) == 2) {
        expMaxF32Ub = (__local_mem__ float*)workLocal.GetPhyAddr(offset + offset1 + srcM * srcK);
    }
    SoftmaxFlashV2NZWithTailUpdateVFImpl<T1, T2>(dstUb, expSumUb, inExpSumUb, maxUb, inMaxUb, srcUb, expMaxUb,
        expMaxF32Ub, workUb, tmpUb, expUb, originalSrcShape, tiling);
}

template <typename T1, typename T2, bool isUpdate = false, bool isBasicBlock = false>
__aicore__ inline void SoftMaxFlashV2NZImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<T1>& expMaxTensor,
    const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    static_assert((SupportType<T1, float>() && SupportType<T2, float>()) ||
                  (SupportType<T1, half>() && SupportType<T2, half>()) ||
                  (SupportType<T1, half>() && SupportType<T2, float>()),
                  "SoftMaxFlashV2 api only support half/float on current device");
    if constexpr (!isUpdate) {
        if (originalSrcShape.k != tiling.srcK) {
            SoftmaxFlashV2NZWithTailNoUpdateImpl<T1, T2>(dstTensor, expSumTensor, maxTensor, srcTensor,
                workLocal, originalSrcShape, tiling);
        } else {
            SoftmaxFlashV2NZNoUpdateImpl<T1, T2>(dstTensor, expSumTensor, maxTensor, srcTensor,
                workLocal, originalSrcShape, tiling);
        }
    } else {
        if (originalSrcShape.k != tiling.srcK) {
            SoftmaxFlashV2NZWithTailUpdateImpl<T1, T2>(dstTensor, expSumTensor, maxTensor, srcTensor,
                expMaxTensor, inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
        } else {
            SoftmaxFlashV2NZUpdateImpl<T1, T2>(dstTensor, expSumTensor, maxTensor, srcTensor, expMaxTensor,
                inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
        }
    }
}

template <typename T1, typename T2>
__simd_vf__ inline void SoftmaxFlashV2NDUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* expSumUb, __local_mem__ T2* inExpSumUb,
    __local_mem__ T2* maxUb, __local_mem__ T2* inMaxUb, __local_mem__ T1* srcUb, __local_mem__ T1* expMaxUb,
    const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = originalSrcShape.m;
    uint16_t srcK = tiling.srcK;
    uint16_t originK = originalSrcShape.k;
    uint16_t repeatTimes = CeilDivision(originK, FLOAT_REPEAT_SIZE);
    uint16_t blockStride = IsSameType<T2, half>::value ? B16_DATA_NUM_PER_BLOCK
                                                       : B32_DATA_NUM_PER_BLOCK;
    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOneBlk;
    if constexpr (IsSameType<T2, half>::value) {
        pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL16>();
    } else {
        pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL8>();
    }
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> expMaxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> dstVreg;
    MicroAPI::RegTensor<T1> t1Reg;

    for (uint16_t i = 0; i < srcM; ++i) {
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < repeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);
        }
        MicroAPI::ReduceMax(maxVreg, maxVreg, pregFull);
        Duplicate(maxVreg, maxVreg, pregOneBlk);
        LoadIfNeedCast<T2>(tmpVreg, inMaxUb + i * blockStride, pregOneBlk);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregOneBlk);
        StoreIfNeedCast<T2>(maxUb + i * blockStride, maxVreg, pregOneBlk);

        MicroAPI::FusedExpSub(expMaxVreg, tmpVreg, maxVreg, pregOneBlk);
        if constexpr (sizeof(T1) == 2 && sizeof (T2) == 4) {
            MicroAPI::Cast<T1, float, Internal::castTraitB32ToB16>(t1Reg, expMaxVreg, pregOneBlk);
            MicroAPI::Pack<uint16_t, uint32_t>((MicroAPI::RegTensor<uint16_t> &)t1Reg,
                (MicroAPI::RegTensor<uint32_t> &)t1Reg);
            MicroAPI::DataCopy<T1, MicroAPI::StoreDist::DIST_INTLV_B16>(
                expMaxUb + i * blockStride * 2, t1Reg, t1Reg, pregOneBlk);
        } else {
            StoreIfNeedCast<T1>(expMaxUb + i * blockStride, expMaxVreg, pregOneBlk);
        }

        Duplicate(sumVreg, 0);
        Duplicate(maxVreg, maxVreg, pregFull);
        for (uint16_t j = 0; j < repeatTimes; ++j) {
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregFull);
            StoreIfNeedCast<T1>(dstUb + i * srcK + j * FLOAT_REPEAT_SIZE, tmpVreg, pregFull);
            MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        }
        MicroAPI::ReduceSum(sumVreg, sumVreg, pregFull);
        Duplicate(sumVreg, sumVreg, pregOneBlk);
        LoadIfNeedCast<T2>(tmpVreg, inExpSumUb + i * blockStride, pregOneBlk);
        MicroAPI::Mul(tmpVreg, expMaxVreg, tmpVreg, pregOneBlk);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregOneBlk);
        StoreIfNeedCast<T2>(expSumUb + i * blockStride, sumVreg, pregOneBlk);
    }
}

template <typename T1, typename T2>
__aicore__ inline void SoftmaxFlashV2NDUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<T1>& expMaxTensor,
    const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* inExpSumUb = (__local_mem__ T2*)inExpSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T2* inMaxUb = (__local_mem__ T2*)inMaxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T1* expMaxUb = (__local_mem__ T1*)expMaxTensor.GetPhyAddr();

    SoftmaxFlashV2NDUpdateVFImpl<T1, T2>(dstUb, expSumUb, inExpSumUb, maxUb, inMaxUb, srcUb,
        expMaxUb, originalSrcShape, tiling);
}

template <typename T1, typename T2>
__simd_vf__ inline void SoftmaxFlashV2NDWithTailUpdateVFImpl(__local_mem__ T1* dstUb, __local_mem__ T2* expSumUb, __local_mem__ T2* inExpSumUb,
    __local_mem__ T2* maxUb, __local_mem__ T2* inMaxUb, __local_mem__ T1* srcUb, __local_mem__ T1* expMaxUb,
    const LastAxisShapeND originalSrcShape, const SoftMaxTiling tiling)
{
    uint16_t srcM = originalSrcShape.m;
    uint16_t srcK = tiling.srcK;
    uint16_t originK = originalSrcShape.k;
    uint16_t repeatTimes = CeilDivision(originK, FLOAT_REPEAT_SIZE);
    uint16_t blockStride = IsSameType<T2, half>::value ? B16_DATA_NUM_PER_BLOCK
                                                       : B32_DATA_NUM_PER_BLOCK;
    NotNumUnion notNum;
    notNum.i = F32_NEG_INF;

    MicroAPI::MaskReg pregCnt;
    MicroAPI::MaskReg pregFull = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::ALL>();
    MicroAPI::MaskReg pregOneBlk;
    if constexpr (IsSameType<T2, half>::value) {
        pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL16>();
    } else {
        pregOneBlk = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::VL8>();
    }
    MicroAPI::RegTensor<float> srcVreg;
    MicroAPI::RegTensor<float> maxVreg;
    MicroAPI::RegTensor<float> expMaxVreg;
    MicroAPI::RegTensor<float> sumVreg;
    MicroAPI::RegTensor<float> tmpVreg;
    MicroAPI::RegTensor<float> minVreg;
    MicroAPI::RegTensor<float> dstVreg;
    MicroAPI::RegTensor<T1> t1Reg;

    Duplicate(minVreg, notNum.f);
    for (uint16_t i = 0; i < srcM; ++i) {
        uint32_t sreg = originK;
        Duplicate(maxVreg, notNum.f);
        for (uint16_t j = 0; j < static_cast<uint16_t>(repeatTimes - 1); ++j) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregCnt);
        }
        pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
        LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + (repeatTimes - 1) * FLOAT_REPEAT_SIZE, pregFull);
        MicroAPI::Select(srcVreg, srcVreg, minVreg, pregCnt);
        MicroAPI::Max(maxVreg, maxVreg, srcVreg, pregFull);

        MicroAPI::ReduceMax(maxVreg, maxVreg, pregFull);
        Duplicate(maxVreg, maxVreg, pregOneBlk);
        LoadIfNeedCast<T2>(tmpVreg, inMaxUb + i *blockStride, pregOneBlk);
        MicroAPI::Max(maxVreg, maxVreg, tmpVreg, pregOneBlk);
        StoreIfNeedCast<T2>(maxUb + i * blockStride, maxVreg, pregOneBlk);

        MicroAPI::FusedExpSub(expMaxVreg, tmpVreg, maxVreg, pregOneBlk);
        if constexpr (sizeof(T1) == 2 && sizeof (T2) == 4) {
            MicroAPI::Cast<T1, float, Internal::castTraitB32ToB16>(t1Reg, expMaxVreg, pregOneBlk);
            MicroAPI::Pack<uint16_t, uint32_t>((MicroAPI::RegTensor<uint16_t> &)t1Reg,
                (MicroAPI::RegTensor<uint32_t> &)t1Reg);
            MicroAPI::DataCopy<T1, MicroAPI::StoreDist::DIST_INTLV_B16>(
                expMaxUb + i * blockStride * 2, t1Reg, t1Reg, pregOneBlk);
        } else {
            StoreIfNeedCast<T1>(expMaxUb + i * blockStride, expMaxVreg, pregOneBlk);
        }

        Duplicate(sumVreg, 0);
        Duplicate(maxVreg, maxVreg, pregFull);
        sreg = originK;
        for (uint16_t j = 0; j < repeatTimes; ++j) {
            pregCnt = MicroAPI::UpdateMask<uint32_t>(sreg);
            LoadIfNeedCast<T1>(srcVreg, srcUb + i * srcK + j * FLOAT_REPEAT_SIZE, pregFull);
            MicroAPI::FusedExpSub(tmpVreg, srcVreg, maxVreg, pregCnt);
            StoreIfNeedCast<T1>(dstUb + i * srcK + j * FLOAT_REPEAT_SIZE, tmpVreg, pregCnt);
            MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregFull);
        }
        MicroAPI::ReduceSum(sumVreg, sumVreg, pregFull);
        Duplicate(sumVreg, sumVreg, pregOneBlk);
        LoadIfNeedCast<T2>(tmpVreg, inExpSumUb + i * blockStride, pregOneBlk);
        MicroAPI::Mul(tmpVreg, expMaxVreg, tmpVreg, pregOneBlk);
        MicroAPI::Add(sumVreg, sumVreg, tmpVreg, pregOneBlk);
        StoreIfNeedCast<T2>(expSumUb + i * blockStride, sumVreg, pregOneBlk);
    }
}

template <typename T1, typename T2>
__aicore__ inline void SoftmaxFlashV2NDWithTailUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<T1>& expMaxTensor,
    const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    __local_mem__ T1* dstUb = (__local_mem__ T1*)dstTensor.GetPhyAddr();
    __local_mem__ T2* expSumUb = (__local_mem__ T2*)expSumTensor.GetPhyAddr();
    __local_mem__ T2* inExpSumUb = (__local_mem__ T2*)inExpSumTensor.GetPhyAddr();
    __local_mem__ T2* maxUb = (__local_mem__ T2*)maxTensor.GetPhyAddr();
    __local_mem__ T2* inMaxUb = (__local_mem__ T2*)inMaxTensor.GetPhyAddr();
    __local_mem__ T1* srcUb = (__local_mem__ T1*)srcTensor.GetPhyAddr();
    __local_mem__ T1* expMaxUb = (__local_mem__ T1*)expMaxTensor.GetPhyAddr();

    SoftmaxFlashV2NDWithTailUpdateVFImpl<T1, T2>(dstUb, expSumUb, inExpSumUb, maxUb, inMaxUb, srcUb,
        expMaxUb, originalSrcShape, tiling);
}

template <typename T1, typename T2, bool isBasicBlock = false, bool outputBrc = true>
__aicore__ inline void SoftmaxFlashV2NDNoUpdateImpl(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    if constexpr (isBasicBlock) {
        SoftMaxGenericNDImpl<T1, T2, true, false, outputBrc>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
            originalSrcShape, tiling);
    } else {
        if (tiling.srcK == B32_DATA_NUM_PER_BLOCK && IsSameType<T1, float>::value) {
            SingleSoftMaxGenericNDForBlkImpl<T1, T2, true, false, outputBrc>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
                originalSrcShape, tiling);
        } else if (tiling.srcK == B32_DATA_NUM_PER_BLOCK * 2) {
            SingleSoftMaxGenericNDAlignedWithBlkImpl<T1, T2, true, false, outputBrc>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
                originalSrcShape, tiling);
        } else if (originalSrcShape.k <= FLOAT_REPEAT_SIZE) {
            SingleSoftMaxGenericNDImpl<T1, T2, true, false, outputBrc>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
                originalSrcShape, tiling);
        } else if (originalSrcShape.k % FLOAT_REPEAT_SIZE != 0) {
            SoftMaxGenericNDWithTailImpl<T1, T2, true, false, outputBrc>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
                originalSrcShape, tiling);
        } else {
            SoftMaxGenericNDImpl<T1, T2, true, false, outputBrc>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
                originalSrcShape, tiling);
        }
    }
}

template <typename T1, typename T2, bool isUpdate = false, bool isBasicBlock = false, bool isOutputReduceMax = false>
__aicore__ inline void SoftmaxFlashV2M1PostProcess(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& outReduceMax,
    const LocalTensor<T2>& expSumTensor, const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor,
    const LocalTensor<T1>& expMaxTensor, const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor,
    const LocalTensor<float>& workLocal, const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    static_assert((SupportType<T1, float>() && SupportType<T2, float>()) ||
                  (SupportType<T1, half>() && SupportType<T2, half>()) ||
                  (SupportType<T1, half>() && SupportType<T2, float>()),
                  "SoftMaxFlashV2 api only support half/float on current device");
    if constexpr (isBasicBlock && isUpdate) {
        SoftmaxFlashV2M1NDUpdateImpl<T1, T2, isOutputReduceMax>(dstTensor, outReduceMax, expSumTensor, maxTensor, srcTensor, expMaxTensor,
            inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
    } else {
        if constexpr (!isUpdate) {
            SoftmaxFlashV2NDNoUpdateImpl<T1, T2, isBasicBlock, false>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
                originalSrcShape, tiling);
        } else {
            if (originalSrcShape.k % FLOAT_REPEAT_SIZE) {
                SoftmaxFlashV2M1NDWithTailUpdateImpl<T1, T2, isOutputReduceMax>(dstTensor, outReduceMax, expSumTensor, maxTensor, srcTensor,
                    expMaxTensor, inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
            } else {
                SoftmaxFlashV2M1NDUpdateImpl<T1, T2, isOutputReduceMax>(dstTensor, outReduceMax, expSumTensor, maxTensor, srcTensor, expMaxTensor,
                    inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
            }
        }
    }
}

template <typename T1, typename T2, bool isUpdate = false, bool isBasicBlock = false, const SoftmaxConfig& config>
__aicore__ inline void SoftmaxFlashV2PostProcess(const LocalTensor<T1>& dstTensor, const LocalTensor<T2>& expSumTensor,
    const LocalTensor<T2>& maxTensor, const LocalTensor<T1>& srcTensor, const LocalTensor<T1>& expMaxTensor,
    const LocalTensor<T2>& inExpSumTensor, const LocalTensor<T2>& inMaxTensor, const LocalTensor<float>& workLocal,
    const LastAxisShapeND& originalSrcShape, const SoftMaxTiling& tiling)
{
    static_assert((SupportType<T1, float>() && SupportType<T2, float>()) ||
                  (SupportType<T1, half>() && SupportType<T2, half>()) ||
                  (SupportType<T1, half>() && SupportType<T2, float>()),
                  "SoftMaxFlashV2 api only support half/float on current device");
    if constexpr (isBasicBlock && isUpdate) {
        SoftmaxFlashV2NDUpdateImpl<T1, T2>(dstTensor, expSumTensor, maxTensor, srcTensor, expMaxTensor,
            inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
    } else {
        if constexpr (!isUpdate) {
            SoftmaxFlashV2NDNoUpdateImpl<T1, T2, isBasicBlock>(dstTensor, expSumTensor, maxTensor, srcTensor, workLocal,
                originalSrcShape, tiling);
        } else {
            if (originalSrcShape.k % FLOAT_REPEAT_SIZE) {
                SoftmaxFlashV2NDWithTailUpdateImpl<T1, T2>(dstTensor, expSumTensor, maxTensor, srcTensor,
                    expMaxTensor, inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
            } else {
                SoftmaxFlashV2NDUpdateImpl<T1, T2>(dstTensor, expSumTensor, maxTensor, srcTensor, expMaxTensor,
                    inExpSumTensor, inMaxTensor, workLocal, originalSrcShape, tiling);
            }
        }
    }
}
}
#endif // IMPL_ACTIVATION_SOFTMAX_C310_SOFTMAX_FLASHV2_IMPL_H
